{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5c58b13e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:No normalization for BCUT2D_MWHI\n",
      "WARNING:root:No normalization for BCUT2D_MWLOW\n",
      "WARNING:root:No normalization for BCUT2D_CHGHI\n",
      "WARNING:root:No normalization for BCUT2D_CHGLO\n",
      "WARNING:root:No normalization for BCUT2D_LOGPHI\n",
      "WARNING:root:No normalization for BCUT2D_LOGPLOW\n",
      "WARNING:root:No normalization for BCUT2D_MRHI\n",
      "WARNING:root:No normalization for BCUT2D_MRLOW\n"
     ]
    }
   ],
   "source": [
    "#加载DeepPurpose和TDC\n",
    "from DeepPurpose import utils, CompoundPred\n",
    "from DeepPurpose import DDI as models\n",
    "from tdc.multi_pred import DDI\n",
    "from tdc.chem_utils import MolConvert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dda3acb2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found local copy...\n",
      "Loading...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "#导入数据库\n",
    "X_drugs, X_drugs_, y = DDI(name = 'Drugbank').get_data(format = 'DeepPurpose')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cb14606a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#选用编码器\n",
    "drug_encoding ='MPNN'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9843dc42",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drug Drug Interaction Prediction Mode...\n",
      "in total: 215037 drug-drug pairs\n",
      "encoding drug...\n",
      "unique drugs: 1704\n",
      "encoding drug...\n",
      "unique drugs: 1704\n",
      "Done.\n"
     ]
    }
   ],
   "source": [
    "\n",
    "train, val, test = utils.data_process (X_drug = X_drugs, X_drug_ = X_drugs_, y = y, \n",
    "\t\t\t    drug_encoding = drug_encoding,\n",
    "\t\t\t    split_method='random',frac=[0.6,0.2,0.2], \n",
    "\t\t\t    random_seed = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "49506056",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SMILES 1</th>\n",
       "      <th>SMILES 2</th>\n",
       "      <th>Label</th>\n",
       "      <th>drug_encoding_1</th>\n",
       "      <th>drug_encoding_2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CC1=CC2=CC3=C(OC(=O)C=C3C)C(C)=C2O1</td>\n",
       "      <td>COC(=O)CCC1=C2NC(\\C=C3/N=C(/C=C4\\N\\C(=C/C5=N/C...</td>\n",
       "      <td>1</td>\n",
       "      <td>[[[tensor(1.), tensor(0.), tensor(0.), tensor(...</td>\n",
       "      <td>[[[tensor(1.), tensor(0.), tensor(0.), tensor(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>CC(C(O)=O)C1=CC2=C(C=C1)C1=C(N2)C=CC(Cl)=C1</td>\n",
       "      <td>COC(=O)CCC1=C2NC(\\C=C3/N=C(/C=C4\\N\\C(=C/C5=N/C...</td>\n",
       "      <td>1</td>\n",
       "      <td>[[[tensor(1.), tensor(0.), tensor(0.), tensor(...</td>\n",
       "      <td>[[[tensor(1.), tensor(0.), tensor(0.), tensor(...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                      SMILES 1  \\\n",
       "0          CC1=CC2=CC3=C(OC(=O)C=C3C)C(C)=C2O1   \n",
       "1  CC(C(O)=O)C1=CC2=C(C=C1)C1=C(N2)C=CC(Cl)=C1   \n",
       "\n",
       "                                            SMILES 2  Label  \\\n",
       "0  COC(=O)CCC1=C2NC(\\C=C3/N=C(/C=C4\\N\\C(=C/C5=N/C...      1   \n",
       "1  COC(=O)CCC1=C2NC(\\C=C3/N=C(/C=C4\\N\\C(=C/C5=N/C...      1   \n",
       "\n",
       "                                     drug_encoding_1  \\\n",
       "0  [[[tensor(1.), tensor(0.), tensor(0.), tensor(...   \n",
       "1  [[[tensor(1.), tensor(0.), tensor(0.), tensor(...   \n",
       "\n",
       "                                     drug_encoding_2  \n",
       "0  [[[tensor(1.), tensor(0.), tensor(0.), tensor(...  \n",
       "1  [[[tensor(1.), tensor(0.), tensor(0.), tensor(...  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0fe35f37",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SMILES 1</th>\n",
       "      <th>SMILES 2</th>\n",
       "      <th>Label</th>\n",
       "      <th>drug_encoding_1</th>\n",
       "      <th>drug_encoding_2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>COC1=CC2=C(C=C1)C=C(CCC(C)=O)C=C2</td>\n",
       "      <td>[Ca++].[Ca++].[Ca++].OC(CC([O-])=O)(CC([O-])=O...</td>\n",
       "      <td>0</td>\n",
       "      <td>[[[tensor(1.), tensor(0.), tensor(0.), tensor(...</td>\n",
       "      <td>[[[tensor(0.), tensor(0.), tensor(0.), tensor(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>OCCOCCN1CCN(CC1)C1=NC2=CC=CC=C2SC2=CC=CC=C12</td>\n",
       "      <td>CN1CCN(CC1)C(C1=CC=CC=C1)C1=CC=CC=C1</td>\n",
       "      <td>1</td>\n",
       "      <td>[[[tensor(0.), tensor(0.), tensor(1.), tensor(...</td>\n",
       "      <td>[[[tensor(1.), tensor(0.), tensor(0.), tensor(...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       SMILES 1  \\\n",
       "0             COC1=CC2=C(C=C1)C=C(CCC(C)=O)C=C2   \n",
       "1  OCCOCCN1CCN(CC1)C1=NC2=CC=CC=C2SC2=CC=CC=C12   \n",
       "\n",
       "                                            SMILES 2  Label  \\\n",
       "0  [Ca++].[Ca++].[Ca++].OC(CC([O-])=O)(CC([O-])=O...      0   \n",
       "1               CN1CCN(CC1)C(C1=CC=CC=C1)C1=CC=CC=C1      1   \n",
       "\n",
       "                                     drug_encoding_1  \\\n",
       "0  [[[tensor(1.), tensor(0.), tensor(0.), tensor(...   \n",
       "1  [[[tensor(0.), tensor(0.), tensor(1.), tensor(...   \n",
       "\n",
       "                                     drug_encoding_2  \n",
       "0  [[[tensor(0.), tensor(0.), tensor(0.), tensor(...  \n",
       "1  [[[tensor(1.), tensor(0.), tensor(0.), tensor(...  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "17a3364a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(129022, 5)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.shape\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c9570ca2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(43008, 5)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9549886d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(43007, 5)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "721fb516",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = utils.generate_config(drug_encoding = drug_encoding, \n",
    "                         cls_hidden_dims = [512], \n",
    "                         train_epoch =2, \n",
    "                         LR = 0.001, \n",
    "                         batch_size = 256,\n",
    "                        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fb4b1f1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.model_initialize(**config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "88643e7d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Let's use 1 GPU!\n",
      "--- Data Preparation ---\n",
      "--- Go for Training ---\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "only one element tensors can be converted to Python scalars",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_12796\\296990116.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;31m# Training\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrain\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mval\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtest\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m~\\DeepPurpose\\DeepPurpose\\DDI.py\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(self, train, val, test, verbose)\u001b[0m\n\u001b[0;32m    256\u001b[0m                                         \u001b[0mv_p\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mv_p\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    257\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 258\u001b[1;33m                                 \u001b[0mscore\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mv_d\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mv_p\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    259\u001b[0m                                 \u001b[0mlabel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mVariable\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlabel\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    260\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\ana\\envs\\01gpu\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\DeepPurpose\\DeepPurpose\\DDI.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, v_D, v_P)\u001b[0m\n\u001b[0;32m     42\u001b[0m                 \u001b[1;31m# each encoding\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     43\u001b[0m                 \u001b[0mv_D\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel_drug\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mv_D\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m                 \u001b[0mv_P\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel_drug\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mv_P\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     45\u001b[0m                 \u001b[1;31m# concatenate and classify\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     46\u001b[0m                 \u001b[0mv_f\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mv_D\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mv_P\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mE:\\ana\\envs\\01gpu\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\DeepPurpose\\DeepPurpose\\encoders.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, feature)\u001b[0m\n\u001b[0;32m    265\u001b[0m                 \u001b[0mfatoms_lst\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfbonds_lst\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0magraph_lst\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbgraph_lst\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    266\u001b[0m                 \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mN_atoms_bond\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 267\u001b[1;33m                         \u001b[0matom_num\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mN_atoms_bond\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    268\u001b[0m                         \u001b[0mbond_num\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mN_atoms_bond\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    269\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mValueError\u001b[0m: only one element tensors can be converted to Python scalars"
     ]
    }
   ],
   "source": [
    "# Training\n",
    "model.train(train, val, test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75d75a5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_model('./4_model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "697d37ca",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ef54e43",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
